
Tree

  • Tree problems usually revolve around DFS or BFS combined with some specific tree property, so it's important to be able to implement these two algorithms blindfolded

Overview

Traversal

DFS

  • The idea is usually to recursively traverse the left and right subtrees while performing some operation at each node

  • Sometimes it's helpful to pass extra parameters to the recursive function to keep track of state along the way (e.g. path sum, current path)

  • Preorder Traversal

    • Root => Left => Right
def preorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
    if not root:
        return []

    left = self.preorderTraversal(root.left)
    right = self.preorderTraversal(root.right)

    return [root.val] + left + right
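The recursive version above builds new lists at every call and is bounded by Python's recursion limit (1000 frames by default). The same Root => Left => Right order can be produced with an explicit stack. A minimal standalone sketch; the function name `preorder_iterative` and the bare `TreeNode` class are assumptions that mirror LeetCode's node shape:

```python
from typing import List, Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def preorder_iterative(root: Optional[TreeNode]) -> List[int]:
    """Root => Left => Right using an explicit stack instead of recursion."""
    if not root:
        return []
    res = []
    stack = [root]
    while stack:
        node = stack.pop()
        res.append(node.val)
        # Push right before left so the left child is popped (visited) first
        if node.right:
            stack.append(node.right)
        if node.left:
            stack.append(node.left)
    return res


#        1
#       / \
#      2   3
#     / \
#    4   5
tree = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(preorder_iterative(tree))  # [1, 2, 4, 5, 3]
```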
  • Inorder Traversal
    • Left => Root => Right
def inorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
    if not root:
        return []

    left = self.inorderTraversal(root.left)
    right = self.inorderTraversal(root.right)

    return left + [root.val] + right
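Inorder can also be done iteratively: push the whole left spine onto a stack, pop a node, visit it, then repeat on its right subtree. A sketch under the same assumptions (hypothetical `inorder_iterative`, minimal `TreeNode`); on a BST the output comes out sorted:

```python
from typing import List, Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_iterative(root: Optional[TreeNode]) -> List[int]:
    """Left => Root => Right using an explicit stack."""
    res = []
    stack = []
    node = root
    while node or stack:
        # Walk down the left spine, remembering each node on the way
        while node:
            stack.append(node)
            node = node.left
        # The popped node's left subtree is finished, so visit it now
        node = stack.pop()
        res.append(node.val)
        # Then handle its right subtree the same way
        node = node.right
    return res


# BST:     4
#         / \
#        2   6
#       / \
#      1   3
bst = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(6))
print(inorder_iterative(bst))  # [1, 2, 3, 4, 6]
```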

https://leetcode.com/problems/construct-binary-tree-from-preorder-and-inorder-traversal

  • Important to understand how the properties of inorder and preorder traversal fit together
def buildTree(self, preorder: List[int], inorder: List[int]) -> Optional[TreeNode]:
    if not preorder or not inorder:
        return None

    root = TreeNode(preorder[0])
    # preorder: root + left + right
    # inorder:  left + root + right
    # The root's position in inorder (rootIndex) equals the size of the left subtree,
    # so it tells us where to split both arrays into left and right subarrays
    rootIndex = inorder.index(root.val)
    root.left = self.buildTree(preorder[1:rootIndex + 1], inorder[:rootIndex])
    root.right = self.buildTree(preorder[rootIndex + 1:], inorder[rootIndex + 1:])

    return root
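`inorder.index` and the slices each cost O(n), so the version above is O(n^2) in the worst case (a skewed tree). One common refinement, sketched here as a standalone function with the hypothetical name `build_tree`, precomputes each value's inorder position in a dict and passes index bounds instead of slices, giving O(n) overall. Like the problem, it assumes all values are unique:

```python
from typing import Dict, List, Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(preorder: List[int], inorder: List[int]) -> Optional[TreeNode]:
    """O(n) rebuild: dict lookups and index bounds instead of list.index and slicing."""
    position: Dict[int, int] = {val: i for i, val in enumerate(inorder)}
    roots = iter(preorder)  # preorder hands out subtree roots in exactly the order we build them

    def helper(lo: int, hi: int) -> Optional[TreeNode]:
        # Builds the subtree whose values are inorder[lo:hi]
        if lo >= hi:
            return None
        root = TreeNode(next(roots))
        mid = position[root.val]
        # Left must be built before right to stay in step with preorder (root, left, right)
        root.left = helper(lo, mid)
        root.right = helper(mid + 1, hi)
        return root

    return helper(0, len(inorder))


root = build_tree([3, 9, 20, 15, 7], [9, 3, 15, 20, 7])
print(root.val, root.left.val, root.right.val)  # 3 9 20
```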
  • Postorder Traversal
    • Left => Right => Root
def postorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
    if not root:
        return []

    left = self.postorderTraversal(root.left)
    right = self.postorderTraversal(root.right)

    return left + right + [root.val]
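Postorder is the trickiest order to do iteratively. One well-known shortcut: a stack-based traversal in Root => Right => Left order, reversed at the end, yields Left => Right => Root. A sketch (hypothetical `postorder_iterative`, minimal `TreeNode`):

```python
from typing import List, Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def postorder_iterative(root: Optional[TreeNode]) -> List[int]:
    """Left => Right => Root: collect Root => Right => Left with a stack, then reverse."""
    if not root:
        return []
    res = []
    stack = [root]
    while stack:
        node = stack.pop()
        res.append(node.val)
        # Push left before right so the right child is popped first
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)
    # Reversing Root => Right => Left gives Left => Right => Root
    return res[::-1]


tree = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(postorder_iterative(tree))  # [4, 5, 2, 3, 1]
```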

https://leetcode.com/problems/construct-binary-tree-from-preorder-and-postorder-traversal

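  • Important to understand how the properties of postorder and preorder traversal fit together: preorder[1] is the root of the left subtree, and its position in postorder marks where the left subtree ends
def constructFromPrePost(self, preorder: List[int], postorder: List[int]) -> Optional[TreeNode]:
    if not preorder or not postorder:
        return None

    root = TreeNode(preorder[0])
    if len(preorder) > 1:
        # preorder[1] is the left child's value; in postorder it is the LAST value
        # of the left subtree, so leftValIndex + 1 is the left subtree's size
        leftVal = preorder[1]
        leftValIndex = postorder.index(leftVal)
        root.left = self.constructFromPrePost(
            preorder[1 : 1 + leftValIndex + 1], postorder[: leftValIndex + 1]
        )
        # postorder[-1] is the current root, so exclude it from the right subtree
        root.right = self.constructFromPrePost(
            preorder[1 + leftValIndex + 1 :], postorder[leftValIndex + 1 : -1]
        )

    return root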
  • Example of carrying extra state (the running sum) through the recursion:
# https://leetcode.com/problems/path-sum/
def hasPathSum(self, root: Optional[TreeNode], targetSum: int) -> bool:
    def getSum(node, curSum):
        if not node:
            return False
        # Only a leaf can complete a root-to-leaf path
        if not node.left and not node.right and curSum + node.val == targetSum:
            return True

        left = getSum(node.left, curSum + node.val)
        right = getSum(node.right, curSum + node.val)

        return left or right

    return getSum(root, 0)

https://leetcode.com/problems/boundary-of-binary-tree

  • The implementation itself is not bad and is good practice for DFS thinking
  • The hard part is understanding WTF the description is asking for
def boundaryOfBinaryTree(self, root: Optional[TreeNode]) -> List[int]:
    # Boundary = root + left boundary (top-down, no leaves)
    #          + all leaves (left to right) + right boundary (bottom-up, no leaves)
    if not root:
        return []

    def getLeft(node):
        # Walk down the left edge, preferring the left child; stop before a leaf
        if not node:
            return []
        if not node.left and not node.right:
            return []
        current = [node.val]
        if node.left:
            return current + getLeft(node.left)
        else:
            return current + getLeft(node.right)

    def getRight(node):
        # Walk down the right edge, preferring the right child; values come out bottom-up
        if not node:
            return []
        if not node.left and not node.right:
            return []
        current = [node.val]
        if node.right:
            return getRight(node.right) + current
        else:
            return getRight(node.left) + current

    leaves = []

    def getLeaves(node):
        if not node:
            return
        if not node.left and not node.right:
            leaves.append(node.val)
            return

        getLeaves(node.left)
        getLeaves(node.right)

    left = getLeft(root.left)
    current = [root.val]
    right = getRight(root.right)
    # A lone root is already in `current`, so don't count it again as a leaf
    if root.left or root.right:
        getLeaves(root)

    return current + left + leaves + right

BFS

  • The idea is to use a queue to keep track of which node to visit next
  • Use when
    • Level order traversal
      • When the problem asks about relationships between nodes at the same height/depth
      • When you need to find the minimum depth or shortest path
    • Parent relationship
      • When you need to keep track of the current node's parent (and maybe compare it to other nodes on the same level)
# https://leetcode.com/problems/binary-tree-level-order-traversal/
def levelOrder(self, root: Optional[TreeNode]) -> List[List[int]]:
    res = []
    if not root:
        return res
    # deque.popleft() is O(1); list.pop(0) is O(n) because it shifts every remaining element
    queue = collections.deque([root])

    while queue:
        cur_level = []
        # Snapshot the size so the inner loop only handles the current level
        cur_len = len(queue)
        for _ in range(cur_len):
            node = queue.popleft()
            cur_level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        res.append(cur_level)

    return res
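The "minimum depth" case listed above is where BFS beats DFS: levels are visited in order, so the first leaf dequeued is the shallowest one and the search can stop early. A minimal sketch, assuming minimum depth means the number of nodes on the shortest root-to-leaf path (the function name `min_depth` is hypothetical):

```python
from collections import deque
from typing import Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def min_depth(root: Optional[TreeNode]) -> int:
    """Nodes on the shortest root-to-leaf path, found level by level."""
    if not root:
        return 0
    queue = deque([root])
    depth = 1
    while queue:
        # range(len(queue)) is evaluated once, so this loop covers exactly one level
        for _ in range(len(queue)):
            node = queue.popleft()
            # Levels are visited in order, so the first leaf seen is the shallowest
            if not node.left and not node.right:
                return depth
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        depth += 1
    return depth  # not reached: a non-empty tree always has a leaf


#    3
#   / \
#  9   20
#     /  \
#    15   7
tree = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(min_depth(tree))  # 2
```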

https://leetcode.com/problems/maximum-level-sum-of-a-binary-tree

def maxLevelSum(self, root: Optional[TreeNode]) -> int:
    maxVal = float('-inf')
    maxLevel = -1

    queue = collections.deque([root])
    level = 1

    while queue:
        curLen = len(queue)
        curSum = 0
        for _ in range(curLen):
            node = queue.popleft()
            curSum += node.val
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        # Strict ">" keeps the smallest level on ties, as the problem requires
        if curSum > maxVal:
            maxVal = curSum
            maxLevel = level
        level += 1

    return maxLevel

https://leetcode.com/problems/cousins-in-binary-tree

def isCousins(self, root: Optional[TreeNode], x: int, y: int) -> bool:
    if not root:
        return False

    # Each queue entry: (node, parent, depth)
    queue = collections.deque([(root, None, 0)])
    xDepth, xParent = -1, None
    yDepth, yParent = -1, None

    while queue:
        node, parent, depth = queue.popleft()
        if node.val == x:
            xParent, xDepth = parent, depth
        elif node.val == y:
            yParent, yDepth = parent, depth

        # Both nodes found: stop the whole search, not just the current level
        if xDepth != -1 and yDepth != -1:
            break

        if node.left:
            queue.append((node.left, node, depth + 1))
        if node.right:
            queue.append((node.right, node, depth + 1))

    # Cousins: same depth, different parents
    return xDepth == yDepth and xParent is not yParent

https://leetcode.com/problems/cousins-in-binary-tree-ii

  • More advanced version of the previous problem
  • The trick is to efficiently compute, for every node, the sum of the nodes on the same level that do NOT share its parent
    • Pre-computation: new node.val = levelTotal - (node.val + sibling.val), so each level's total only needs to be computed once
    • Hash table: store node.val + sibling.val for each position in the level
def replaceValueInTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
    if not root:
        return None
    queue = collections.deque([(root, None)])

    while queue:
        # Collect the whole current level (with parents) and enqueue the next one
        curLen = len(queue)
        level = []
        for _ in range(curLen):
            node, parent = queue.popleft()
            level.append((node, parent))
            if node.left:
                queue.append((node.left, node))
            if node.right:
                queue.append((node.right, node))

        # negate[i] = node.val + its sibling's val (siblings are adjacent in BFS order)
        negate = {}
        for i in range(len(level)):
            node, parent = level[i]
            negate[i] = node.val
            if i > 0 and parent is level[i - 1][1]:
                negate[i - 1] += node.val
                negate[i] += level[i - 1][0].val

        total = sum(node.val for node, _ in level)

        # Cousin sum = level total minus the node itself and its sibling
        for i in range(len(level)):
            node, _ = level[i]
            node.val = total - negate[i]

    return root

Problems

https://leetcode.com/problems/recover-a-tree-from-preorder-traversal

  • More of a string parsing question than a pure tree DFS question
  • The trick is that we can only attach the next node if the current depth matches its number of dashes, so we also have to keep track of the current depth during the DFS
def recoverFromPreorder(self, traversal: str) -> Optional[TreeNode]:
    # Returns (node, index after the consumed input); node is None when the
    # next token does not belong at this depth or the input is exhausted
    def helper(start, depth):
        if start >= len(traversal):
            return None, start
        dashes = 0
        i = start
        while i < len(traversal) and traversal[i] == "-":
            dashes += 1
            i += 1
        if dashes != depth:
            # Not our child: leave the input untouched for an ancestor to consume
            return None, start

        num = 0
        while i < len(traversal) and traversal[i] != "-":
            num = num * 10 + int(traversal[i])
            i += 1

        node = TreeNode(num)
        # The problem guarantees that a single child is always the left child
        node.left, nextStart = helper(i, depth + 1)
        node.right, nextStart2 = helper(nextStart, depth + 1)

        return node, nextStart2

    return helper(0, 0)[0]

Variants

Binary Search Tree

  • Definition: a binary tree (each node has at most two children) where, for every node:

    • All values in the left subtree are less than the node's value
    • All values in the right subtree are greater than the node's value
    • As a consequence, an in-order traversal yields the values in sorted order
  • Search, insert, and delete take O(h) time, where h is the height: O(log n) when the tree is balanced, but O(n) when it degenerates into a chain
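Search is fast because each comparison discards one entire subtree, so the cost is proportional to the tree's height. A minimal iterative search sketch; `search_bst` and the bare `TreeNode` class are illustrative assumptions:

```python
from typing import Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def search_bst(root: Optional[TreeNode], target: int) -> Optional[TreeNode]:
    """Return the node holding target, or None, using one comparison per level."""
    node = root
    while node:
        if target == node.val:
            return node
        # BST property: target can only be in one of the two subtrees
        node = node.left if target < node.val else node.right
    return None


#      4
#     / \
#    2   7
#   / \
#  1   3
bst = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(7))
print(search_bst(bst, 3).val)  # 3
print(search_bst(bst, 5))      # None
```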

  • These are some of the fundamental problems that demonstrate BST properties

def sortedArrayToBST(self, nums: List[int]) -> Optional[TreeNode]:
    if not nums:
        return None

    # The middle element becomes the root, so the two halves differ in size by at most 1
    mid = len(nums) // 2
    root = TreeNode(nums[mid])
    root.left = self.sortedArrayToBST(nums[:mid])
    root.right = self.sortedArrayToBST(nums[mid + 1:])

    return root

https://leetcode.com/problems/construct-binary-search-tree-from-preorder-traversal

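  • Good question for noticing BST properties: in a BST's preorder, everything after the root is a run of smaller values (left subtree) followed by a run of larger values (right subtree), so the split point can be binary searched
def bstFromPreorder(self, preorder: List[int]) -> Optional[TreeNode]:
    if not preorder:
        return None
    root = TreeNode(preorder[0])
    # Binary search for the first index whose value is greater than the root
    left, right = 1, len(preorder) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if preorder[mid] > root.val:
            right = mid - 1
        else:
            left = mid + 1

    # preorder[1:left] is the left subtree, preorder[left:] is the right subtree
    root.left = self.bstFromPreorder(preorder[1:left])
    root.right = self.bstFromPreorder(preorder[left:])

    return root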
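The binary-search version above still copies slices at every level. A common alternative, sketched here with the hypothetical name `bst_from_preorder`, consumes each value exactly once (O(n)) by passing down an upper bound: a value belongs to the current subtree only while it is below the bound inherited from its ancestors. It assumes distinct values, as in the problem:

```python
from typing import List, Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bst_from_preorder(preorder: List[int]) -> Optional[TreeNode]:
    """Rebuild a BST from its preorder in O(n), assuming distinct values."""
    i = 0  # index of the next unconsumed value

    def build(upper: float) -> Optional[TreeNode]:
        nonlocal i
        # Stop if input is exhausted or the next value belongs to some ancestor's right side
        if i == len(preorder) or preorder[i] > upper:
            return None
        root = TreeNode(preorder[i])
        i += 1
        root.left = build(root.val)  # left subtree values must be < root.val
        root.right = build(upper)    # right subtree keeps the ancestor's bound
        return root

    return build(float("inf"))


root = bst_from_preorder([8, 5, 1, 7, 10, 12])
print(root.val, root.left.val, root.right.val)  # 8 5 10
```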

https://leetcode.com/problems/validate-binary-search-tree

def isValidBST(self, root: Optional[TreeNode]) -> bool:
    # Every node must lie strictly inside the (lower, upper) range set by its ancestors
    def helper(node, lower, upper):
        if not node:
            return True
        if node.val >= upper or node.val <= lower:
            return False

        return helper(node.left, lower, node.val) and helper(node.right, node.val, upper)

    return helper(root, float('-inf'), float('inf'))

https://leetcode.com/problems/insert-into-a-binary-search-tree

def insertIntoBST(self, root: Optional[TreeNode], val: int) -> Optional[TreeNode]:
    # Reached an empty spot: the new value goes here
    if not root:
        return TreeNode(val)

    if val > root.val:
        root.right = self.insertIntoBST(root.right, val)
    elif val < root.val:
        root.left = self.insertIntoBST(root.left, val)

    return root

https://leetcode.com/problems/delete-node-in-a-bst

  • Important algorithm to memorize
  • 3 steps:
    1. Find the node to remove
    2. If it has two children, replace its value with its inorder successor (the smallest value in its right subtree); the inorder predecessor (largest value in the left subtree) works too
    3. Remove the successor from the right subtree
  • If the node has zero or one child, simply return the other child (or None) in its place
def deleteNode(self, root: Optional[TreeNode], key: int) -> Optional[TreeNode]:
    def findSuccessor(node):
        # Smallest node in a subtree: keep going left
        if node.left:
            return findSuccessor(node.left)
        return node

    if not root:
        return None

    if key > root.val:
        root.right = self.deleteNode(root.right, key)
    elif key < root.val:
        root.left = self.deleteNode(root.left, key)
    else:
        if not root.left and not root.right:
            return None
        elif not root.left:
            return root.right
        elif not root.right:
            return root.left
        else:
            # Two children: copy the successor's value, then delete the successor
            successor = findSuccessor(root.right)
            root.val = successor.val
            root.right = self.deleteNode(root.right, root.val)

    return root

Lowest Common Ancestor

https://leetcode.com/discuss/interview-question/6024811

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree

def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    # Found p or q (or fell off the tree): report it upward
    if not root or root.val == p.val or root.val == q.val:
        return root

    left = self.lowestCommonAncestor(root.left, p, q)
    right = self.lowestCommonAncestor(root.right, p, q)

    # p and q were found in different subtrees, so this node is the LCA
    if left and right:
        return root

    return left or right

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-ii

  • p or q may not exist in the tree, so use the helper/DFS function both to find the common ancestor and to check that each node actually exists
  • Maintain 2 variables to record whether each node was found
  • Recurse into both children BEFORE checking the current node, so the whole tree is still visited after p or q is found
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    foundP = False
    foundQ = False

    def helper(node):
        nonlocal foundP
        nonlocal foundQ
        if not node:
            return None

        # Post-order: explore children first so no node is skipped
        left = helper(node.left)
        right = helper(node.right)

        if node.val == p.val:
            foundP = True
            return node
        if node.val == q.val:
            foundQ = True
            return node

        if left and right:
            return node
        return left or right

    res = helper(root)
    return res if foundP and foundQ else None

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-search-tree

# Recursive
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    if root.val == q.val or root.val == p.val:
        return root
    # p and q fall on different sides: this is where their paths split
    if (p.val < root.val < q.val) or (q.val < root.val < p.val):
        return root

    # Both smaller: the LCA is in the left subtree
    if p.val < root.val and q.val < root.val:
        return self.lowestCommonAncestor(root.left, p, q)

    # Both larger: the LCA is in the right subtree
    if p.val > root.val and q.val > root.val:
        return self.lowestCommonAncestor(root.right, p, q)

# Iterative
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    node = root
    while node:
        if node.val > p.val and node.val > q.val:
            node = node.left
        elif node.val < p.val and node.val < q.val:
            node = node.right
        else:
            # node lies between p and q (inclusive), so it is the split point
            return node

    return None

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-iii

  • Walk from one node all the way up to the root, recording every node visited
  • Then walk up from the other node; the first node already recorded is the LCA
def lowestCommonAncestor(self, p: 'Node', q: 'Node') -> 'Node':
    # Record every ancestor of p (including p itself)
    visit = set()
    while p:
        visit.add(p)
        p = p.parent

    # The first ancestor of q that is also an ancestor of p is the LCA
    while q:
        if q in visit:
            return q
        q = q.parent

    return None
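The set-based version uses extra space proportional to the depth. Because both nodes have parent pointers, the walk-up can also be done in O(1) space with the two-pointer trick used for intersecting linked lists: each pointer walks up to the root and then restarts from the other starting node, so both cover the same total distance and first coincide at the LCA. A sketch; `lca_with_parents` is a hypothetical name, and the `Node` class here only carries `val` and `parent`, which is an assumption for brevity:

```python
from typing import Optional


class Node:
    # Assumed shape for this sketch: only val and parent (LeetCode's Node also has left/right)
    def __init__(self, val: int, parent: Optional["Node"] = None):
        self.val = val
        self.parent = parent


def lca_with_parents(p: Node, q: Node) -> Node:
    a, b = p, q
    # Each pointer walks up to the root, then restarts from the OTHER starting node.
    # Both then cover the same total distance, so they first coincide at the LCA.
    while a is not b:
        a = a.parent if a else q
        b = b.parent if b else p
    return a


#        3
#       / \
#      5   1
#     / \
#    6   2
#       / \
#      7   4
root = Node(3)
n5, n1 = Node(5, root), Node(1, root)
n6, n2 = Node(6, n5), Node(2, n5)
n7, n4 = Node(7, n2), Node(4, n2)
print(lca_with_parents(n6, n4).val)  # 5
print(lca_with_parents(n7, n1).val)  # 3
```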

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-iv

def lowestCommonAncestor(self, root: 'TreeNode', nodes: 'List[TreeNode]') -> 'TreeNode':
    # Same idea as the 2-node version, generalized with a set lookup
    nodes = set(nodes)

    def dfs(node):
        if not node or node in nodes:
            return node
        left = dfs(node.left)
        right = dfs(node.right)

        if left and right:
            return node
        return left or right

    return dfs(root)

https://leetcode.com/problems/lowest-common-ancestor-of-deepest-leaves

def lcaDeepestLeaves(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
    # Returns (height of subtree, LCA of the deepest leaves in that subtree)
    def depth(node):
        if not node:
            return 0, node

        leftDepth, lcaLeft = depth(node.left)
        rightDepth, lcaRight = depth(node.right)

        if leftDepth == rightDepth:
            # Deepest leaves exist on both sides: this node is their LCA
            return leftDepth + 1, node
        elif leftDepth > rightDepth:
            return leftDepth + 1, lcaLeft
        else:
            return rightDepth + 1, lcaRight

    return depth(root)[1]

Tree Diameter

https://leetcode.com/problems/diameter-of-binary-tree

def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
    res = 0

    # Returns the height (in nodes) of the subtree rooted at node
    def dfs(node):
        nonlocal res
        if not node:
            return 0
        left = dfs(node.left)
        right = dfs(node.right)

        # Longest path (in edges) that bends at this node
        res = max(res, left + right)

        return 1 + max(left, right)

    dfs(root)
    return res

https://leetcode.com/problems/diameter-of-n-ary-tree

  • Same idea as the binary version, except each node combines its 2 longest child paths
def diameter(self, root: 'Node') -> int:
    res = 0

    # Returns the height (in nodes) of the subtree rooted at node
    def dfs(node):
        nonlocal res
        if not node:
            return 0
        # Track the two tallest child subtrees
        maxChild, secondMaxChild = 0, 0
        for child in node.children:
            current = dfs(child)
            if current >= maxChild:
                secondMaxChild = maxChild
                maxChild = current
            elif current > secondMaxChild:
                secondMaxChild = current

        # Longest path (in edges) through this node joins the two tallest children
        res = max(res, maxChild + secondMaxChild)

        return 1 + maxChild

    dfs(root)
    return res

Patterns

Tree as Graph

  • Sometimes a problem is given as a tree, but we actually want to treat the tree as (or convert it into) a [[9. Graph| graph]].
  • In a typical tree problem, we only need to
    • Move from parent to children (downward)
    • Process nodes in a specific order
    • Track information along a single path
  • These signals/characteristics in a problem should make you consider transforming the tree into a graph:

Non-standard movement requirements

https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree

def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> List[int]:
    # Build an undirected adjacency list so we can also move up to the parent
    def buildGraph(node, parent, graph):
        if not node:
            return
        if parent:
            graph[node.val].append(parent.val)
        if node.left:
            graph[node.val].append(node.left.val)
            buildGraph(node.left, node, graph)
        if node.right:
            graph[node.val].append(node.right.val)
            buildGraph(node.right, node, graph)

    graph = collections.defaultdict(list)
    buildGraph(root, None, graph)

    # BFS outward from target; every node reached at distance k is an answer
    visit = {target.val}
    res = []
    queue = collections.deque([(target.val, 0)])

    while queue:
        node, dist = queue.popleft()
        if dist == k:
            # Don't expand further: nothing beyond distance k is needed
            res.append(node)
            continue
        for nb in graph[node]:
            if nb not in visit:
                visit.add(nb)
                queue.append((nb, dist + 1))

    return res
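Every solution in this section starts by building the same undirected adjacency list. A reusable standalone sketch of that step plus a BFS distance map; `tree_to_graph` and `bfs_distances` are hypothetical helper names, and keying by value assumes node values are unique, which these problems guarantee:

```python
from collections import defaultdict, deque
from typing import Dict, List, Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_to_graph(root: Optional[TreeNode]) -> Dict[int, List[int]]:
    """Undirected adjacency list keyed by node value (assumes unique values)."""
    graph: Dict[int, List[int]] = defaultdict(list)
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child:
                # Store each edge in both directions so a search can move up as well as down
                graph[node.val].append(child.val)
                graph[child.val].append(node.val)
                stack.append(child)
    return graph


def bfs_distances(graph: Dict[int, List[int]], start: int) -> Dict[int, int]:
    """Edge distance from start to every reachable node."""
    dist = {start: 0}
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        for nb in graph.get(cur, []):
            if nb not in dist:
                dist[nb] = dist[cur] + 1
                queue.append(nb)
    return dist


#          3
#        /   \
#       5     1
#      / \   / \
#     6   2 0   8
#        / \
#       7   4
tree = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
dist = bfs_distances(tree_to_graph(tree), 5)
print(sorted(v for v, d in dist.items() if d == 2))  # [1, 4, 7]
```

Once the graph and distance map exist, "nodes at distance k", "time to infect the whole tree" (the largest distance), and "distance between p and q" all become lookups on `dist`.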

https://leetcode.com/problems/amount-of-time-for-binary-tree-to-be-infected/

def amountOfTime(self, root: Optional[TreeNode], start: int) -> int:
    graph = collections.defaultdict(list)

    def buildGraph(node, parent):
        if not node:
            return
        if parent:
            graph[node.val].append(parent.val)
        if node.left:
            graph[node.val].append(node.left.val)
            buildGraph(node.left, node)
        if node.right:
            graph[node.val].append(node.right.val)
            buildGraph(node.right, node)

    buildGraph(root, None)
    # BFS from start: the answer is the distance to the farthest node
    infected = {start}
    queue = collections.deque([(start, 0)])
    res = 0

    while queue:
        node, dist = queue.popleft()
        res = max(res, dist)
        for nb in graph[node]:
            if nb not in infected:
                # Mark when enqueuing so no node is queued twice
                infected.add(nb)
                queue.append((nb, dist + 1))

    return res

https://leetcode.com/problems/step-by-step-directions-from-a-binary-tree-node-to-another


https://leetcode.com/problems/closest-leaf-in-a-binary-tree

def findClosestLeaf(self, root: Optional[TreeNode], k: int) -> int:
    graph = collections.defaultdict(list)
    leaf = set()

    def buildGraph(node, parent):
        if not node:
            return
        if parent:
            graph[node.val].append(parent.val)
        if not node.left and not node.right:
            leaf.add(node.val)
        if node.left:
            graph[node.val].append(node.left.val)
            buildGraph(node.left, node)
        if node.right:
            graph[node.val].append(node.right.val)
            buildGraph(node.right, node)

    buildGraph(root, None)
    # BFS from k: the first leaf reached is the closest one
    visit = {k}
    queue = collections.deque([k])

    while queue:
        curNode = queue.popleft()
        if curNode in leaf:
            return curNode
        for nb in graph[curNode]:
            if nb not in visit:
                visit.add(nb)
                queue.append(nb)

    return -1

https://leetcode.com/problems/minimum-time-to-collect-all-apples-in-a-tree

def minTime(self, n: int, edges: List[List[int]], hasApple: List[bool]) -> int:
    graph = {i: [] for i in range(n)}

    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    # Returns the time needed to collect every apple below current and come back
    def dfs(current, parent):
        time = 0

        for nb in graph[current]:
            if nb == parent:
                continue
            nbTime = dfs(nb, current)
            # Enter nb only if it or its subtree has an apple: 2 seconds for the round trip
            if nbTime or hasApple[nb]:
                time += nbTime + 2
        return time

    return dfs(0, -1)

Relationship-based queries

  • When a problem asks about relationships that aren't purely hierarchical, consider a graph
    • Finding nodes at a specific distance
    • Finding the distance between any two nodes
    • Finding all nodes that can reach a target node
    • Finding the shortest path between nodes

https://leetcode.com/problems/find-distance-in-a-binary-tree

def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
    graph = collections.defaultdict(list)

    def buildGraph(node, parent):
        if not node:
            return
        if parent:
            graph[node.val].append(parent.val)
        if node.left:
            graph[node.val].append(node.left.val)
            buildGraph(node.left, node)
        if node.right:
            graph[node.val].append(node.right.val)
            buildGraph(node.right, node)

    buildGraph(root, None)
    # BFS from p; the distance at which q is popped is the number of edges between them
    visit = {p}
    queue = collections.deque([(p, 0)])

    while queue:
        node, dst = queue.popleft()
        if node == q:
            return dst
        for nb in graph[node]:
            if nb not in visit:
                visit.add(nb)
                queue.append((nb, dst + 1))

    return -1

https://leetcode.com/problems/binary-tree-maximum-path-sum


Parent access needed

  • If you find yourself thinking "I need to know this node's parent" or "I need to move upward from this node", that's often a signal that a graph representation might be helpful.
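When only upward moves are needed, a full adjacency list may be more than necessary: a single BFS pass that records each node's parent is enough. A sketch with hypothetical helpers `build_parent_map` and `path_to_root`; keying by the node object (a plain `TreeNode` hashes by identity) avoids assuming unique values:

```python
from collections import deque
from typing import Dict, List, Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_parent_map(root: Optional[TreeNode]) -> Dict[TreeNode, Optional[TreeNode]]:
    """Map every node to its parent (the root maps to None) in one BFS pass."""
    parent: Dict[TreeNode, Optional[TreeNode]] = {}
    if not root:
        return parent
    parent[root] = None
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in (node.left, node.right):
            if child:
                parent[child] = node
                queue.append(child)
    return parent


def path_to_root(node: Optional[TreeNode], parent: Dict[TreeNode, Optional[TreeNode]]) -> List[int]:
    """Values from node up to the root, following the parent map."""
    values = []
    while node:
        values.append(node.val)
        node = parent[node]
    return values


tree = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3))
parent = build_parent_map(tree)
print(path_to_root(tree.left.left, parent))  # [4, 2, 1]
```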

https://leetcode.com/problems/find-all-the-lonely-nodes


https://leetcode.com/problems/find-nearest-right-node-in-binary-tree


https://leetcode.com/problems/lowest-common-ancestor-of-deepest-leaves


Misc

NOTES

  • Because each node in a binary tree has at most 2 children, the tree can be stored in an array (with no gaps when the tree is complete, as in a binary heap), which gives us:

    Binary Tree in Array Representation: if a binary tree is stored as an array,

    1. Index 0 represents the root node.
    2. For a node at index i:
      • Left Child: located at index 2i + 1.
      • Right Child: located at index 2i + 2.
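The index formulas above, plus the parent formula (i - 1) // 2 that inverts them, can be sketched in code. This assumes a heap-style layout where `None` marks a missing node; note that it is not the same as LeetCode's level-order serialization, which skips the children of null nodes. `from_array` and the index helpers are hypothetical names:

```python
from typing import List, Optional


class TreeNode:
    # Minimal node shape, assumed to match LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def left_child(i: int) -> int:
    return 2 * i + 1


def right_child(i: int) -> int:
    return 2 * i + 2


def parent_index(i: int) -> int:
    # Inverts both formulas above; only meaningful for i > 0 (the root has no parent)
    return (i - 1) // 2


def from_array(values: List[Optional[int]], i: int = 0) -> Optional[TreeNode]:
    """Build the subtree rooted at index i; None (or running off the end) means no node."""
    if i >= len(values) or values[i] is None:
        return None
    return TreeNode(values[i],
                    from_array(values, left_child(i)),
                    from_array(values, right_child(i)))


values = [1, 2, 3, 4, 5, None, 6]
root = from_array(values)
print(root.right.right.val)  # 6 (index 6 is the right child of index 2)
print(parent_index(6))       # 2
```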